{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Load and minibatch MNIST data\n",
    "(c) Deniz Yuret, 2019\n",
    "* Objective: Load the [MNIST](http://yann.lecun.com/exdb/mnist) dataset, convert it into Julia arrays, and split it into minibatches using Knet's [minibatch](http://denizyuret.github.io/Knet.jl/latest/reference/#Knet.minibatch) function and [Data](https://github.com/denizyuret/Knet.jl/blob/master/src/data.jl) iterator type.\n",
    "* Prerequisites: [Julia arrays](https://docs.julialang.org/en/v1/manual/arrays)\n",
    "* New functions: [dir](http://denizyuret.github.io/Knet.jl/latest/reference/#Knet.dir), [minibatch, Data](http://denizyuret.github.io/Knet.jl/latest/reference/#Knet.minibatch), [mnist, mnistview](https://github.com/denizyuret/Knet.jl/blob/master/data/mnist.jl)\n",
    "\n",
    "In the next few notebooks, we build classification models for the MNIST handwritten digit recognition dataset. MNIST has 60000 training and 10000 test examples. Each input x consists of 784 pixels representing a 28x28 image. The corresponding output indicates the identity of the digit 0..9.\n",
    "\n",
    "![](http://yann.lecun.com/exdb/lenet/gifs/asamples.gif \"MNIST\")\n",
    "\n",
    "[image source](http://yann.lecun.com/exdb/lenet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Load packages, import symbols\n",
    "using Pkg; for p in (\"Knet\",\"Images\",\"ImageMagick\"); haskey(Pkg.installed(),p) || Pkg.add(p); end\n",
    "using Knet: Knet, dir, minibatch, Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Loading MNIST...\n",
      "└ @ Main /home/gridsan/dyuret/.julia/dev/Knet/data/mnist.jl:33\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "28×28×1×60000 Array{Float32,4}\n",
      "60000-element Array{UInt8,1}\n",
      "28×28×1×10000 Array{Float32,4}\n",
      "10000-element Array{UInt8,1}\n"
     ]
    }
   ],
   "source": [
    "# This loads the MNIST handwritten digit recognition dataset:\n",
    "include(Knet.dir(\"data\",\"mnist.jl\")) # Knet.dir constructs a path relative to Knet root\n",
    "xtrn,ytrn,xtst,ytst = mnist()        # mnist() loads MNIST data and converts it into Julia arrays\n",
    "println.(summary.((xtrn,ytrn,xtst,ytst)));"
   ]
  },
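  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "The `28×28×1×N` arrays printed above and the 784 pixels per input mentioned in the introduction describe the same data. As a minimal sketch using only Base Julia's `reshape` (the name `xflat` is made up for this illustration), each image can be flattened into a 784-element column:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Flatten each 28x28x1 image into one column of 28*28*1 = 784 pixels.\n",
    "# `xflat` is a hypothetical name; reshape shares memory with xtst rather than copying.\n",
    "xflat = reshape(xtst, :, size(xtst, 4))\n",
    "println(summary(xflat));"
   ]
  },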
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARgAAAA4CAAAAAADPrjSAAAABGdBTUEAALGPC/xhBQAAACBjSFJNAAB6JgAAgIQAAPoAAACA6AAAdTAAAOpgAAA6mAAAF3CculE8AAAAAmJLR0QA/4ePzL8AAAUASURBVGje7dppbBVVGIDhp8WCFhJCsSgaA1ojCVqxgruIW0CRIkIRIv5AEyDGLaK4gEZR0aQmGleMJu5GiYgGAVGDETXiUoOKWDXWWFciKCooRODWHw71bud27qURNPP+6Zyv3/nmTPvec87MXBISEhISEnYVyjqz2BX2cKgGzLXC4zv72naI8p09gF2VTjRmnoa0VotTfV1ClYN86lJ3x87v7jbTvG+81k79wyTGBNitswpt9+VTLzlAvRrnuqWEOnVSvisifx9TpAw2yr2x+xxugf4ZkeGafZMRSYwJ0EnGDHEWVhttnY26etsgVSVVOszvFsTOrvZICecYoVtWZLTzTcyIJMYEyDCmwRTf2+xJa3xRVJm+yqw2wg/gCgOxuITh1LrYY7GzLzHGkdHxCcp96PVYlzwyJ9Zkuu5+T4skxgTIMKYxmqun2WB1VuK3GjUFy7zgQBv8HLUmqChxOANUmhc7+w6p9uOxxmo1wfsd9jrJMRqzYlUGqkyMiUOGMVMM8omB6pzoaN/YD2y1Vl98XcAYafvOGQ7CO94pYThXai14lnSWpP1Xf7JRP/t7V5cOetV6SkvODmt0Tl5iTIAMY5ZZhqXopU6TI8Bmn2tW5ctYBUe5UVc/usYfRQ+mvyE+z/ikhxlmgFQ0x9zvZb862SxcYG7BfrN0d7qNGbEqw9Jmq79JjAkQ2Pmu9yqWRa1xelnl6VgFh+iKeZaXMJhhWBsrs7+n7QlaPWu2P9BqqmqNdnePLYF+DUb6wntZ0VlSXvNLRiwxJkCMe6U+7lPuxvZdSiGeNxyPubakwdSSs8PIT0Xky3ITrYtirW51u0qNFmoJ9BuvMmcO6m+SbeZkWZYYEyCGMReqtt5nMYr1daxu1rk5a9aPxzHOs9IrRfRocn67L7DQpGglzU9PR+O+rOhUe2r2alY0MSZAh8Yc52qM8XGMYs/qjSeCn/DCnKLKUptjZpfjqKxYmXLlmO3cvH262TfP2lpDnqtLjAnQoTEjVVhmRYxSox2O11xf4lAGaTM/Zu60nJ0q1KuTkgqOYIMP1KrKWF/7aMCbObmJMQE6MGYPp/nT9cGd5D/0NlMFPihpRWJvQ33muZjZ9TmRagPNBGuDo92kxTiL3R61D1Gjnza05eQmxgTowJgZ6iz1VoxClzsCz5c8w0zWx4s7cCGzXAi+MrnAG9AblDnDU1FrnbZoB/1wTmZiTICCxpzhOr+5KVah6eCiEmcY+mF9yZexxIDoqNkbBfKana1OTdSaj0dNwqaczMSYAAWM6e0uXSyJtYfZTlX7ivCrLSr0RC+XgW2uKvBUrx6LYp+nTDlOx4P6orx9XzOqw74rrUxr/f1cstaqrKzEmABBY7pYan8triuq3EftR8/4wV4mZPx2jTmBfkPtVdR55mrEIikiV7Y//S2WMmXk+JIYEyRoTI3BmB77TnmJMzPa46OfW6WwUJN8dyTbGaOLlUU8J15ghuqMyFrNpkbvzouhLc+ul8SYIAFj+nkZM4pYJ8a6MnpjfXA0szzkKyzQ3GHfSiMx37bYZ2s10RiXpkXmFPGNqnR2J+8zoMSYAIFvbc5xDY6M/R55x6iw3I/OKfrd5WmmqrfQA8p8UtJ3RFljNze5MyeeGBMgrzFDLdbDv2fMzuQFd+S8ISAxJkjeVel4PdBS8p3yf4n6QDwxJkBw5/uhU2K9rf6/khiTkNAp/AWFPiBLadIx6wAAAABJRU5ErkJggg==",
      "text/plain": [
       "28×140 Array{Gray{Float32},2} with eltype Gray{Float32}:\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)  …  Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)  …  Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)  …  Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " ⋮                                       ⋱                    \n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)  …  Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)  …  Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)\n",
       " Gray{Float32}(0.0)  Gray{Float32}(0.0)     Gray{Float32}(0.0)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# mnistview(x,i) converts the i'th instance in x into an image\n",
    "# Here are the first five images from the test set:\n",
    "using Images\n",
    "hcat([mnistview(xtst,i) for i=1:5]...)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[7, 2, 1, 10, 4]\n"
     ]
    }
   ],
   "source": [
    "# Here are their labels (10 is used to represent 0)\n",
    "println(Int.(ytst[1:5]));"
   ]
  },
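  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Julia arrays are 1-indexed, which is presumably why the loader stores the digit 0 as class 10: labels in 1..10 can index the rows of a score matrix directly. If the digit values themselves are needed, one small sketch (the name `digit_values` is hypothetical) is to take the labels modulo 10:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Map class labels 1..10 back to digit values 0..9 (10 -> 0); labels 1..9 are unchanged.\n",
    "digit_values = mod.(Int.(ytst[1:5]), 10)\n",
    "println(digit_values);"
   ]
  },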
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Data{Tuple{Array{Float32,4},Array{UInt8,1}}}(Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], UInt8[0x07 0x02 … 0x05 0x06], 100, 10000, false, 9901, 1:10000, false, (28, 28, 1, 10000), (10000,), Array{Float32,4}, Array{UInt8,1})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# `minibatch` splits the data tensors into small chunks called minibatches.\n",
    "# It returns a Knet.Data struct: an iterator of (x,y) pairs.\n",
    "dtrn = minibatch(xtrn,ytrn,100)\n",
    "dtst = minibatch(xtst,ytst,100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "28×28×1×100 Array{Float32,4}\n",
      "100-element Array{UInt8,1}\n"
     ]
    }
   ],
   "source": [
    "# Each minibatch is an (x,y) pair where x holds 100 images (28x28x1 each) and y holds the corresponding 100 labels.\n",
    "# Here is the first minibatch in the test set:\n",
    "(x,y) = first(dtst)\n",
    "println.(summary.((x,y)));"
   ]
  },
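  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "To see what a minibatch contains, the first one can also be cut by hand with ordinary array indexing. This is only a sketch of the idea, not Knet's implementation; it assumes the default unshuffled order used above, and `xb`, `yb` are hypothetical names:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Slice the first 100 instances along the last (instance) dimension.\n",
    "batchsize = 100\n",
    "xb = xtst[:, :, :, 1:batchsize]   # 28x28x1x100 block of images\n",
    "yb = ytst[1:batchsize]            # the matching 100 labels\n",
    "println((size(xb), size(yb)))\n",
    "println(xb == x && yb == y)       # compare with the minibatch drawn above"
   ]
  },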
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "600"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Iterators can be used in for loops, e.g. `for (x,y) in dtrn`\n",
    "# dtrn generates 600 minibatches of 100 images (total 60000)\n",
    "# dtst generates 100 minibatches of 100 images (total 10000)\n",
    "n = 0\n",
    "for (x,y) in dtrn\n",
    "    n += 1\n",
    "end\n",
    "n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "julia.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Julia 1.0.3",
   "language": "julia",
   "name": "julia-1.0"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.0.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
